{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.3\n",
      "IPython 6.2.1\n",
      "\n",
      "tensorflow 1.5.0-rc1\n",
      "numpy 1.12.1\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p tensorflow,numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using TensorFlow's Dataset API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TensorFlow provides users with multiple options for providing data to the model. Probably the most common method is to define placeholders in the TensorFlow graph and feed the data from the current Python session into the TensorFlow `Session` using the `feed_dict` parameter. With this approach, a large dataset that does not fit into memory is most conveniently and efficiently stored using NumPy archives, as explained in [Chunking an Image Dataset for Minibatch Training using NumPy NPZ Archives](image-data-chunking-npz.ipynb), or HDF5 database files ([Storing an Image Dataset for Minibatch Training using HDF5](image-data-chunking-hdf5.ipynb)).\n",
    "\n",
    "Another approach, which is often preferred when it comes to computational efficiency, is to do the \"data loading\" directly in the graph using input queues from so-called TFRecords files, which is illustrated in the [Using Input Pipelines to Read Data from TFRecords Files](tfrecords.ipynb) notebook.\n",
    "\n",
    "One could also use input queues to load the data directly in the graph, as shown in [Using Queue Runners to Feed Images Directly from Disk](file-queues.ipynb). The examples in this Jupyter notebook present an alternative to this manual approach, using TensorFlow's \"new\" Dataset API, which is described in more detail here: https://www.tensorflow.org/programmers_guide/datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. The Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's pretend we have a directory of images containing three subdirectories with images for training, validation, and testing. The following function will create such a dataset of images in JPEG format locally for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./train-images-idx3-ubyte.gz\n",
      "Extracting ./train-labels-idx1-ubyte.gz\n",
      "Extracting ./t10k-images-idx3-ubyte.gz\n",
      "Extracting ./t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Note that executing the following code \n",
    "# cell will download the MNIST dataset\n",
    "# and save all the 60,000 images as separate JPEG\n",
    "# files. This might take a few minutes depending\n",
    "# on your machine.\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# load utilities from ../helper.py\n",
    "import sys\n",
    "sys.path.insert(0, '..') \n",
    "from helper import mnist_export_to_jpg\n",
    "\n",
    "np.random.seed(123)\n",
    "mnist_export_to_jpg(path='./')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `mnist_export_to_jpg` function called above creates 3 directories: `mnist_train`, `mnist_valid`, and `mnist_test`. Note that the names of the subdirectories correspond directly to the class labels of the images stored in them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist_train subdirectories ['9', '0', '7', '6', '1', '8', '4', '3', '2', '5']\n",
      "mnist_valid subdirectories ['9', '0', '7', '6', '1', '8', '4', '3', '2', '5']\n",
      "mnist_test subdirectories ['9', '0', '7', '6', '1', '8', '4', '3', '2', '5']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "for i in ('train', 'valid', 'test'): \n",
    "    dirs = [d for d in os.listdir('mnist_%s' % i) if not d.startswith('.')]\n",
    "    print('mnist_%s subdirectories' % i, dirs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make sure that the images look okay, the snippet below plots an example image from the subdirectory `mnist_train/9/`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAEStJREFUeJzt3X2oXPWdx/HPNzfPT5hrbh5I46YWFSWydh10g8uilBRdFC1orKLcxZIINrgF8QH/iaCLD2ytAaWQ2mCE1FpoNCLabZAVFVQyBjV2026D3q3ZxOSGVPOcmOS7f9xJueqd3+9mzsycmXzfL5A7d75z5nzvxM+cufd3zu9n7i4A8YwpuwEA5SD8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCGtvOnc2cOdMXLFjQzl12hOPHjyfrPT09hZ4/dZammRV67qKK9Fb07NOyf/YyDAwMaPfu3aP6wQuF38yulLRSUo+kp939kdTjFyxYoGq1WmSXXWnv3r3J+vTp0ws9/5dfflm3Nm7cuELPXfSNq0hvqW2lfG8TJ05M1lNybzwnTpxI1ou+oTeqUqmM+rENf+w3sx5JT0m6StIFkm4yswsafT4A7VXkd/5LJG1194/d/aikX0u6tjltAWi1IuGfJ+nTYd9vq933FWa2zMyqZlYdHBwssDsAzVQk/CP9UeEbvyi5+yp3r7h7pa+vr8DuADRTkfBvkzR/2PffkrS9WDsA2qVI+DdKOsfMvm1m4yX9UNJLzWkLQKs1PNTn7sfMbLmk/9TQUN9qd/9D0zo7jeSG8g4ePJis54asUkNmBw4cSG47adKkZD03ZHX48OFkvchwW24osMgwZu41z/3cEyZMaHjfnaLQOL+7vyLplSb1AqCNOL0XCIrwA0ERfiAowg8ERfiBoAg/EFRbr+eP6tChQ8l6bqy9yHXpU6ZMSdZz4/THjh1L1qdOndrw9rnLZnP7HjMmfexKjcVPnjw5uW1OK89vaBeO/EBQhB8IivADQRF+ICjCDwRF+IGgGOprg9xQXs7nn3+erKdmku3t7U1umxuSyg1p5Ywd2/j/YkVnHk5dtpsbPs39m3XDUF4OR34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/g6Qm177jDPOaPi59+/fn6znpqDOjWcfOXIkWU+N1eeWb5s2bVqynrssN1XPrQCcq+em9s5dbtwJOr9DAC1B+IGgCD8QFOEHgiL8QFCEHwiK8ANBFRrnN7MBSfskHZd0zN0rzWgqmlZeG15kau3RKLJU9ezZswvtO2ffvn11a7nXJXe9f2oOhW7RjJN8rnD33U14HgBtxMd+IKii4XdJvzez98xsWTMaAtAeRT/2X+bu281slqQNZvZHd39j+ANqbwrLJOmss84quDsAzVLoyO/u22tfd0l6QdIlIzxmlbtX3L3S19dXZHcAmqjh8JvZFDObdvK2pO9L+qhZjQForSIf+2dLeqE2JDJW0q/c/XdN6QpAyzUcfnf/WNLfN7GX01ZuLD03t32R7b/44ovktitXrkzWH3/88WQ99/zjx4+vW7v66quT295zzz3J+qWXXpqsp85ByL2muTUDjh49mqx3w7z+DPUBQRF+ICjCDwRF+IGgCD8QFOEHgmLq7jbIDeWllpKW8lNUp7Z/4oknkts+/PDDyXpuau7cUtaHDh2qW1u3bl2hfeeGIc8999y6tdxQX27q7m4YysvhyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTHO3wa5JbinTJmSrKemoJak559/vm7t0UcfTW6bG0u/8847k/VKJT1b+6ZNm+rW3n777eS2r776arKeu5x4w4YNdWtFx+lzl/SmLmXuFBz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoxvnbIHfNe07qmnhJevnllxvedu3atcn6DTfckKznxrtvvfXWurXbb789ue3GjRuT
9a1btybrufMrUnLnAXTDOH4OR34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCo7zm9mqyVdLWmXuy+s3dcr6XlJCyQNSFri7n9tXZvdbcyY9Hvs7t27k/UHH3wwWX/99dfr1u69997ktjfffHOynjtPIDcXwbZt2+rWtmzZktw2Z/78+cn6mWeeWbd24sSJ5LbHjx9P1nPnNxQ9t6MdRnPkf0bSlV+77z5Jr7n7OZJeq30PoItkw+/ub0ja87W7r5W0pnZ7jaTrmtwXgBZr9Hf+2e6+Q5JqX2c1ryUA7dDyP/iZ2TIzq5pZdXBwsNW7AzBKjYZ/p5nNlaTa1131Hujuq9y94u6Vvr6+BncHoNkaDf9Lkvprt/slrW9OOwDaJRt+M3tO0tuSzjOzbWb2I0mPSFpsZn+WtLj2PYAukh3nd/eb6pS+1+ReTlu5td5nzpyZrH/44YfJemr++v7+/ro1Kd+bmSXrO3bsSNaXLl1at/bmm28mt81ZtGhRw9vmfq7cuRnjxo1reN+dgjP8gKAIPxAU4QeCIvxAUIQfCIrwA0ExdXcb5IaFcqc9f/LJJ8l66szJ3KWl27dvL7Tvu+++O1mvVqvJekrudbvmmmuS9dTy4z09Pcltx45NRyM3FNgNuv8nANAQwg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+DpCbJnrq1KnJ+qefflq3dvHFFye3zU3NnZviOndJcEpurHzevHnJeu6S3gkTJpxyTycdO3YsWU+dQyDlpzTvBBz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoxvnbIDeOP2fOnGT9jjvuSNbXrl1bt/bOO+8kt82dQ5A7DyBn4sSJdWu5cwiuuOKKQvtOPX9uHH/8+PHJeu56/27AkR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgsoOVprZaklXS9rl7gtr9z0gaamkkxPO3+/ur7SqyW7n7sl6bsx5+fLlyfr1119ft/b0008X2vfChQuT9RdffDFZT52DMG3atOS2t912W7Je5Jr53L9JTm4eg25Ywns0R/5nJF05wv0/c/eLav8RfKDLZMPv7m9I2tOGXgC0UZHf+Zeb2YdmttrMZjStIwBt0Wj4fy7pO5IukrRD0k/rPdDMlplZ1cyquTXpALRPQ+F3953uftzdT0j6haRLEo9d5e4Vd6+kFpQE0F4Nhd/M5g779geSPmpOOwDaZTRDfc9JulzSTDPbJmmFpMvN7CJJLmlA0u0t7BFAC1jR8c5TUalUvMh67VHl5ogvMj99bq6B9evXJ+s33nhjsp667j23psBbb72VrOekzmHIXY+fe11yenp6Cm3fqEqlomq1aqN5LGf4AUERfiAowg8ERfiBoAg/EBThB4Lq/vmHu0DRyz/NRjVyM6L9+/cn67nps5955plkPXdJcGpq8BUrViS3zQ1DHz16NFlPDYHmtm3lv0mn4MgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzt8FcstFHz58uG4ttwR3bmrvDRs2JOs5V111Vd3a4sWLk9vmxuKLXDabO4cgd/5DWZfsNhNHfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+Nih67XfumvnUNNOpcwAk6cknn0zWc9vPmzcvWb/rrruS9ZTcWHqunup94sSJyW1beY5Bp+DIDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBZcf5zWy+pGclzZF0QtIqd19pZr2Snpe0QNKApCXu/tfWtdq9cteG5+zduzdZ7+3trVt76KGHktt+8MEHDfV00i233JKsp5bhzv1c06dPb6ink4qcX5Gbtz+3hHc3nAcwmiP/MUl3ufv5kv5R0o/N7AJJ90l6zd3PkfRa7XsAXSIbfnff4e6barf3SdoiaZ6kayWtqT1sjaTrWtUkgOY7pd/5zWyBpO9KelfSbHffIQ29QUia1ezmALTOqMNvZlMl/VbST9w9/cvaV7dbZmZVM6sODg420iOAFhhV+M1s
nIaCv9bd19Xu3mlmc2v1uZJ2jbStu69y94q7V/r6+prRM4AmyIbfhv5k+ktJW9z98WGllyT11273S1rf/PYAtMpoLum9TNKtkjab2fu1++6X9Iik35jZjyT9RdINrWmx+xUd9pkxY0ayPjAwULf21FNPFdr3eeedl6z39/cn66lhzqJDebnLjVOX7eYu2c1Nlz5mTPefIpMNv7u/JanegOn3mtsOgHbp/rcvAA0h/EBQhB8IivADQRF+ICjCDwTF1N1tkBvnP3jwYLI+YcKEZP2xxx6rW/vss8+S2+YsXbo0WT///POT9dTPNnny5IZ6OqnIJbu5Jboj4MgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzt8BJk2alKxv3rw5WV+3bl2ynjJnzpxkfcmSJQ0/t5Qey9+1a8TJn/5m1qz0tJC58x8OHTpUt5Z7zXNy07GfLlN3AzgNEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIzzt0GR+eWl/Nz7e/bsOeWeTlq0aFGyPn/+/IafOyc3jl9U0bH8lG4Yx8/hyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQWXH+c1svqRnJc2RdELSKndfaWYPSFoqabD20Pvd/ZVWNdrNcuP4Bw4cSNbPPvvshp//wgsvTG67YsWKZD2n6DX5KM9oTvI5Jukud99kZtMkvWdmG2q1n7n7f7SuPQCtkg2/u++QtKN2e5+ZbZE0r9WNAWitU/qd38wWSPqupHdrdy03sw/NbLWZzaizzTIzq5pZdXBwcKSHACjBqMNvZlMl/VbST9x9r6SfS/qOpIs09MngpyNt5+6r3L3i7pW+vr4mtAygGUYVfjMbp6Hgr3X3dZLk7jvd/bi7n5D0C0mXtK5NAM2WDb8NLYX6S0lb3P3xYffPHfawH0j6qPntAWgVyy1VbGb/JOlNSZs1NNQnSfdLuklDH/ld0oCk22t/HKyrUql4tVot2PLpZ+fOncn67Nmzk/UjR47UrY0fPz65bW6Z6/379yfrU6dOTdbRXpVKRdVqdVRrl4/mr/1vSRrpyRjTB7oYZ/gBQRF+ICjCDwRF+IGgCD8QFOEHgmLq7g7Q29ubrOfOxUgtVZ3bNnWOgCSNGcPx4XTFvywQFOEHgiL8QFCEHwiK8ANBEX4gKMIPBJW9nr+pOzMblPS/w+6aKWl32xo4NZ3aW6f2JdFbo5rZ29+5+6jmy2tr+L+xc7Oqu1dKayChU3vr1L4kemtUWb3xsR8IivADQZUd/lUl7z+lU3vr1L4kemtUKb2V+js/gPKUfeQHUJJSwm9mV5rZn8xsq5ndV0YP9ZjZgJltNrP3zazUecZry6DtMrOPht3Xa2YbzOzPta8jLpNWUm8PmNn/1V67983sX0rqbb6Z/ZeZbTGzP5jZv9XuL/W1S/RVyuvW9o/9ZtYj6X8kLZa0TdJGSTe5+3+3tZE6zGxAUsXdSx8TNrN/lrRf0rPuvrB232OS9rj7I7U3zhnufm+H9PaApP1lr9xcW1Bm7vCVpSVdJ+lfVeJrl+hriUp43co48l8iaau7f+zuRyX9WtK1JfTR8dz9DUl7vnb3tZLW1G6v0dD/PG1Xp7eO4O473H1T7fY+SSdXli71tUv0VYoywj9P0qfDvt+mzlry2yX93szeM7NlZTczgtknV0aqfZ1Vcj9fl125uZ2+trJ0x7x2jax43WxlhH+k1X86acjhMnf/B0lXSfpx7eMtRmdUKze3ywgrS3eERle8brYywr9N0vxh339L0vYS+hiRu2+vfd0l6QV13urDO08uklr7uqvkfv6mk1ZuHmllaXXAa9dJK16XEf6Nks4xs2+b2XhJP5T0Ugl9fIOZTan9IUZmNkXS99V5qw+/JKm/drtf0voSe/mKTlm5ud7K0ir5teu0Fa9LOcmnNpTxhKQeSavd/d/b3sQIzOxsDR3tpaGZjX9VZm9m9pykyzV01ddOSSskvSjpN5LOkvQXSTe4e9v/8Fant8t1iis3
t6i3eitLv6sSX7tmrnjdlH44ww+IiTP8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/dxj83nV7WNAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x104ce8c88>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "some_img = os.path.join('./mnist_train/9/', os.listdir('./mnist_train/9/')[0])\n",
    "\n",
    "img = mpimg.imread(some_img)\n",
    "print(img.shape)\n",
    "plt.imshow(img, cmap='binary');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: The JPEG format introduces a few artifacts that we can see in the image above. We use JPEG instead of PNG here for demonstration purposes, since it is still the format in which many image datasets are stored."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. TensorFlow Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TensorFlow's new Dataset API introduces several new concepts, with the goal of feeding data to the computational graph more efficiently (compared to other approaches like the session's `feed_dict`, which inserts values into the \"placeholders\" of a graph). In this context, a `Dataset` can be understood as a collection of \"elements,\" where each element is composed of one or more tensors, which may have different types.\n",
    "\n",
    "But before we jump further into creating \"datasets,\" let's generate Python lists of the image paths and image labels from the MNIST dataset that we prepared in the first section.\n",
    "\n",
    "Note that TensorFlow datasets can be shuffled, and we will use the `shuffle` operation later on. However, this operation only shuffles a subset of the dataset at a time: the number of elements being shuffled is determined by the `buffer_size` parameter, which also determines how many elements are loaded into memory for shuffling. Since we typically cannot load all elements into memory, it is a good idea to shuffle our dataset, which might be organized by class labels, upfront.\n"
   ]
  },
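  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of why a small `buffer_size` only shuffles locally, the following pure-Python sketch mimics a buffered shuffle. It is a simplified model under stated assumptions, not TensorFlow's actual implementation, and the helper name `buffered_shuffle` as well as the toy labels are made up for this example. The helper fills a buffer with the first `buffer_size` elements, then repeatedly emits a random buffered element and replaces it with the next incoming one. With class-sorted input and a small buffer, the output stays roughly sorted by class, which is why shuffling the path list upfront is helpful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def buffered_shuffle(elements, buffer_size, seed=123):\n",
    "    # Hypothetical helper: a simplified model of a shuffle buffer,\n",
    "    # not TensorFlow's actual implementation.\n",
    "    rng = random.Random(seed)\n",
    "    buffer, out = [], []\n",
    "    for ele in elements:\n",
    "        if len(buffer) < buffer_size:\n",
    "            # fill the buffer first\n",
    "            buffer.append(ele)\n",
    "            continue\n",
    "        # emit a random buffered element and replace it with the new one\n",
    "        idx = rng.randrange(buffer_size)\n",
    "        out.append(buffer[idx])\n",
    "        buffer[idx] = ele\n",
    "    # drain the remaining buffered elements in random order\n",
    "    rng.shuffle(buffer)\n",
    "    out.extend(buffer)\n",
    "    return out\n",
    "\n",
    "\n",
    "# toy labels sorted by class, as if read directory by directory\n",
    "sorted_labels = [0] * 5 + [1] * 5 + [2] * 5\n",
    "\n",
    "# small buffer: the output stays roughly ordered by class\n",
    "print(buffered_shuffle(sorted_labels, buffer_size=3))\n",
    "# buffer as large as the dataset: all elements are shuffled together\n",
    "print(buffered_shuffle(sorted_labels, buffer_size=len(sorted_labels)))"
   ]
  },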
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import random\n",
    "\n",
    "random.seed(123)\n",
    "\n",
    "train_paths = glob.glob('mnist_train/**/*.jpg', recursive=True)\n",
    "train_labels = [int(s.split('/')[1]) for s in train_paths]\n",
    "tmp = list(zip(train_paths, train_labels))\n",
    "random.shuffle(tmp)\n",
    "train_paths, train_labels = zip(*tmp)\n",
    "\n",
    "valid_paths = glob.glob('mnist_valid/**/*.jpg', recursive=True)\n",
    "valid_labels = [int(s.split('/')[1]) for s in valid_paths]\n",
    "tmp = list(zip(valid_paths, valid_labels))\n",
    "random.shuffle(tmp)\n",
    "valid_paths, valid_labels = zip(*tmp)\n",
    "\n",
    "test_paths = glob.glob('mnist_test/**/*.jpg', recursive=True)\n",
    "test_labels = [int(s.split('/')[1]) for s in test_paths]\n",
    "tmp = list(zip(test_paths, test_labels))\n",
    "random.shuffle(tmp)\n",
    "test_paths, test_labels = zip(*tmp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After creating Python lists storing the image paths and corresponding class labels, we can convert these to TensorFlow tensors and use the `Dataset`'s `from_tensor_slices` function to create dataset objects. Based on those `Dataset` objects, we can then create an `Iterator` object to fetch data examples from the \"dataset\"; the \"fetching\" operation is called `.get_next()`. Finally, we need initializers for the different datasets (train, validation, and test) if we want to use them in a TensorFlow session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "data_g1 = tf.Graph()\n",
    "\n",
    "with data_g1.as_default():\n",
    "    \n",
    "    # setup tensor elements for the dataset\n",
    "    tf_train_paths = tf.constant(train_paths)\n",
    "    tf_train_labels = tf.constant(train_labels)\n",
    "    tf_valid_paths = tf.constant(valid_paths)\n",
    "    tf_valid_labels = tf.constant(valid_labels)\n",
    "    tf_test_paths = tf.constant(test_paths)\n",
    "    tf_test_labels = tf.constant(test_labels)\n",
    "    \n",
    "    \n",
    "    # construct datasets from tf.Tensor objects\n",
    "    train_dataset = tf.data.Dataset.from_tensor_slices((tf_train_paths,\n",
    "                                                        tf_train_labels)) \n",
    "    valid_dataset = tf.data.Dataset.from_tensor_slices((tf_valid_paths,\n",
    "                                                        tf_valid_labels)) \n",
    "    test_dataset = tf.data.Dataset.from_tensor_slices((tf_test_paths,\n",
    "                                                       tf_test_labels)) \n",
    "    \n",
    "    # initializing iterator to extract elements from the dataset\n",
    "    #   Note: only need 1 iterator, since validation and test \n",
    "    #   datasets have the same image shapes\n",
    "    iterator = tf.data.Iterator.from_structure(train_dataset.output_types,\n",
    "                                               train_dataset.output_shapes)\n",
    "    \n",
    "    # define op that fetches the next element from the iterator\n",
    "    next_element = iterator.get_next()\n",
    "    \n",
    "    # define initializers for the iterator\n",
    "    train_iter_init = iterator.make_initializer(train_dataset)\n",
    "    valid_iter_init = iterator.make_initializer(valid_dataset)\n",
    "    test_iter_init = iterator.make_initializer(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, after this quite elaborate setup, let us fetch data examples in a TensorFlow session to make sure that our setup works as intended:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fetch element #1 from training dataset:\n",
      "(b'mnist_train/8/20750.jpg', 8)\n",
      "Fetch element #2 from training dataset:\n",
      "(b'mnist_train/7/12707.jpg', 7)\n",
      "Fetch element #3 from training dataset:\n",
      "(b'mnist_train/2/29033.jpg', 2)\n",
      "\n",
      "Fetch element #1 from validation dataset:\n",
      "(b'mnist_valid/9/46835.jpg', 9)\n",
      "Fetch element #2 from validation dataset:\n",
      "(b'mnist_valid/8/45313.jpg', 8)\n",
      "Fetch element #3 from validation dataset:\n",
      "(b'mnist_valid/3/48364.jpg', 3)\n",
      "\n",
      "Fetch element #1 from test dataset:\n",
      "(b'mnist_test/7/09152.jpg', 7)\n",
      "Fetch element #2 from test dataset:\n",
      "(b'mnist_test/5/03120.jpg', 5)\n",
      "Fetch element #3 from test dataset:\n",
      "(b'mnist_test/8/09844.jpg', 8)\n"
     ]
    }
   ],
   "source": [
    "with tf.Session(graph=data_g1) as sess:\n",
    "\n",
    "    sess.run(train_iter_init)\n",
    "    for i in range(3):\n",
    "        print('Fetch element #%d from training dataset:' % (i+1))\n",
    "        ele = sess.run(next_element)\n",
    "        print(ele)\n",
    "    \n",
    "    print()\n",
    "    sess.run(valid_iter_init)\n",
    "    for i in range(3):\n",
    "        print('Fetch element #%d from validation dataset:' % (i+1))\n",
    "        ele = sess.run(next_element)\n",
    "        print(ele)\n",
    "        \n",
    "    print()\n",
    "    sess.run(test_iter_init)\n",
    "    for i in range(3):\n",
    "        print('Fetch element #%d from test dataset:' % (i+1))\n",
    "        ele = sess.run(next_element)\n",
    "        print(ele)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the procedure we defined only yields the file paths and class labels that we previously stored in Python lists. However, this is a good sanity check before we jump to the next step and read and prepare image files for machine learning or deep learning algorithms."
   ]
  },
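  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once a dataset is exhausted, running the `get_next()` op raises `tf.errors.OutOfRangeError`. The following minimal sketch, assuming the same TensorFlow 1.x API used in this notebook, shows how one full pass over a dataset can be detected and how running the initializer restarts the iterator. The graph `toy_g`, the `toy_*` names, and the three values are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "toy_g = tf.Graph()\n",
    "with toy_g.as_default():\n",
    "    # toy dataset with three made-up elements\n",
    "    toy_dataset = tf.data.Dataset.from_tensor_slices(tf.constant([10, 20, 30]))\n",
    "    toy_iterator = tf.data.Iterator.from_structure(toy_dataset.output_types,\n",
    "                                                   toy_dataset.output_shapes)\n",
    "    toy_next = toy_iterator.get_next()\n",
    "    toy_init = toy_iterator.make_initializer(toy_dataset)\n",
    "\n",
    "with tf.Session(graph=toy_g) as sess:\n",
    "    for epoch in range(2):\n",
    "        # (re-)initialize the iterator at the start of each pass\n",
    "        sess.run(toy_init)\n",
    "        while True:\n",
    "            try:\n",
    "                print('Epoch %d:' % (epoch+1), sess.run(toy_next))\n",
    "            except tf.errors.OutOfRangeError:\n",
    "                # raised when the dataset has no more elements\n",
    "                break"
   ]
  },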
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Preprocessing images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Dataset` API allows us to provide custom preprocessing functions that we can apply to our images. Since we are working with a relatively simple dataset such as MNIST, our preprocessing simply consists of reading in the image from a JPEG file, casting it to the correct type, normalizing the pixels to the [0, 1] range, and performing a one-hot encoding on the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_image_jpg_onehot(path, label):\n",
    "    str_tensor = tf.read_file(path)\n",
    "    decoded_image = tf.image.decode_jpeg(str_tensor,\n",
    "                                         channels=1,\n",
    "                                         fancy_upscaling=False)\n",
    "    # normalize to [0, 1] range\n",
    "    decoded_image = tf.cast(decoded_image, tf.float32)\n",
    "    decoded_image = decoded_image / 255.\n",
    "    # depth=10 because we have 10 mnist class labels\n",
    "    onehot_label = tf.one_hot(label, depth=10)\n",
    "    return decoded_image, onehot_label"
   ]
  },
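  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two numeric steps concrete, here is a small NumPy sketch of what the normalization and the one-hot encoding compute. It is only an illustration with made-up pixel values and labels (the `toy_*` names are not part of the pipeline) and is not a replacement for the TensorFlow ops used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# made-up 2x2 \"image\" with uint8 pixel intensities\n",
    "toy_image = np.array([[0, 51], [204, 255]], dtype=np.uint8)\n",
    "# cast to float32 first, then scale to the [0, 1] range\n",
    "toy_image = toy_image.astype(np.float32) / 255.\n",
    "print(toy_image)\n",
    "\n",
    "# made-up integer class labels\n",
    "toy_labels = np.array([3, 0, 9])\n",
    "# row i of the 10x10 identity matrix is the one-hot vector for class i\n",
    "toy_onehot = np.eye(10, dtype=np.float32)[toy_labels]\n",
    "print(toy_onehot.shape)\n",
    "# taking the argmax along each row recovers the integer labels\n",
    "print(toy_onehot.argmax(axis=1))"
   ]
  },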
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to using the `map` function (shown in the next code snippet), we will also shuffle the dataset and yield batches (instead of single training examples). For simplicity, we omit the test dataset, but its preparation is analogous to the validation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "\n",
    "def datareader():\n",
    "    tf_train_paths = tf.constant(train_paths)\n",
    "    tf_train_labels = tf.constant(train_labels)\n",
    "    tf_valid_paths = tf.constant(valid_paths)\n",
    "    tf_valid_labels = tf.constant(valid_labels)\n",
    "    \n",
    "    train_dataset = tf.data.Dataset.from_tensor_slices((tf_train_paths,\n",
    "                                                        tf_train_labels)) \n",
    "    valid_dataset = tf.data.Dataset.from_tensor_slices((tf_valid_paths,\n",
    "                                                        tf_valid_labels)) \n",
    "    \n",
    "    ############################################################\n",
    "    ## Custom data transformation; \n",
    "    #  here: image reading, shuffling, batching\n",
    "    train_dataset = train_dataset.map(read_image_jpg_onehot,\n",
    "                                      num_parallel_calls=4)\n",
    "    train_dataset = train_dataset.shuffle(buffer_size=1000)\n",
    "    train_dataset = train_dataset.batch(BATCH_SIZE)\n",
    "    \n",
    "    valid_dataset = valid_dataset.map(read_image_jpg_onehot,\n",
    "                                      num_parallel_calls=4)\n",
    "    valid_dataset = valid_dataset.batch(BATCH_SIZE)\n",
    "    ############################################################\n",
    "\n",
    "    iterator = tf.data.Iterator.from_structure(train_dataset.output_types,\n",
    "                                               train_dataset.output_shapes)\n",
    "\n",
    "    next_element = iterator.get_next(name='next_element')\n",
    "    \n",
    "    train_iter_init = iterator.make_initializer(train_dataset,\n",
    "                                                name='train_iter_init')\n",
    "    valid_iter_init = iterator.make_initializer(valid_dataset,\n",
    "                                                name='valid_iter_init')\n",
    "    \n",
    "    return next_element\n",
    "\n",
    "\n",
    "data_g2 = tf.Graph()\n",
    "with data_g2.as_default():\n",
    "    datareader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fetch batch #1 from training dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n",
      "Fetch batch #2 from training dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n",
      "Fetch batch #3 from training dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n",
      "\n",
      "Fetch batch #1 from validation dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n",
      "Fetch batch #2 from validation dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n",
      "Fetch batch #3 from validation dataset:\n",
      "(128, 28, 28, 1) (128, 10)\n"
     ]
    }
   ],
   "source": [
    "with tf.Session(graph=data_g2) as sess:\n",
    "\n",
    "    sess.run('train_iter_init')\n",
    "    for i in range(3):\n",
    "        print('Fetch batch #%d from training dataset:' % (i+1))\n",
    "        images, labels = sess.run(['next_element:0', 'next_element:1'])\n",
    "        print(images.shape, labels.shape)\n",
    "        \n",
    "    print()\n",
    "    sess.run('valid_iter_init')\n",
    "    for i in range(3):\n",
    "        print('Fetch batch #%d from validation dataset:' % (i+1))\n",
    "        images, labels = sess.run(['next_element:0', 'next_element:1'])\n",
    "        print(images.shape, labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Using the Dataset API to train a neural network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next code example illustrates how we can use the `Dataset` API to train a simple 2-layer multilayer perceptron. Note that we previously defined a `datareader()` function, so that we can call it here instead of copying and pasting the code into the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "learning_rate = 0.1\n",
    "n_epochs = 15\n",
    "n_iter = n_epochs * (len(train_paths) // BATCH_SIZE)\n",
    "\n",
    "# Architecture\n",
    "n_hidden_1 = 128\n",
    "n_hidden_2 = 256\n",
    "height, width = 28, 28\n",
    "n_classes = 10\n",
    "\n",
    "\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default():\n",
    "    \n",
    "    tf.set_random_seed(123)\n",
    "\n",
    "    # Input data\n",
    "    next_element = datareader()\n",
    "    \n",
    "    tf_images = tf.placeholder_with_default(next_element[0],\n",
    "                                            shape=[None, 28, 28, 1], \n",
    "                                            name='images')\n",
    "    tf_labels = tf.placeholder_with_default(next_element[1], \n",
    "                                            shape=[None, 10], \n",
    "                                            name='labels')\n",
    "    \n",
    "    tf_images = tf.reshape(tf_images, (tf.shape(tf_images)[0], 784))\n",
    "    tf_images = tf.cast(tf_images, dtype=tf.float32)\n",
    "\n",
    "    # Model parameters\n",
    "    weights = {\n",
    "        'h1': tf.Variable(tf.truncated_normal([height*width, n_hidden_1], stddev=0.1)),\n",
    "        'h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev=0.1)),\n",
    "        'out': tf.Variable(tf.truncated_normal([n_hidden_2, n_classes], stddev=0.1))\n",
    "    }\n",
    "    biases = {\n",
    "        'b1': tf.Variable(tf.zeros([n_hidden_1])),\n",
    "        'b2': tf.Variable(tf.zeros([n_hidden_2])),\n",
    "        'out': tf.Variable(tf.zeros([n_classes]))\n",
    "    }\n",
    "\n",
    "    # Multilayer perceptron\n",
    "    layer_1 = tf.add(tf.matmul(tf_images, weights['h1']), biases['b1'])\n",
    "    layer_1 = tf.nn.relu(layer_1)\n",
    "    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])\n",
    "    layer_2 = tf.nn.relu(layer_2)\n",
    "    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']\n",
    "\n",
    "    # Loss and optimizer\n",
    "    loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=out_layer, labels=tf_labels)\n",
    "    cost = tf.reduce_mean(loss, name='cost')\n",
    "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "    train = optimizer.minimize(cost, name='train')\n",
    "\n",
    "    # Prediction\n",
    "    prediction = tf.argmax(out_layer, 1, name='prediction')\n",
    "    correct_prediction = tf.equal(tf.argmax(tf_labels, 1), tf.argmax(out_layer, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')"
   ]
  },
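  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph above wraps the iterator outputs in `tf.placeholder_with_default`, so that the tensors come from the dataset unless a value is passed explicitly via `feed_dict`. The following minimal sketch of this behavior assumes the same TensorFlow 1.x API as the rest of this notebook; the toy graph `demo_g`, the `demo_*` names, and the values are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "demo_g = tf.Graph()\n",
    "with demo_g.as_default():\n",
    "    # stands in for the tensor returned by iterator.get_next()\n",
    "    demo_default = tf.constant([1., 2., 3.])\n",
    "    demo_x = tf.placeholder_with_default(demo_default,\n",
    "                                         shape=[None],\n",
    "                                         name='demo_x')\n",
    "    demo_y = demo_x * 2.\n",
    "\n",
    "with tf.Session(graph=demo_g) as sess:\n",
    "    # no feed_dict: the default tensor is used\n",
    "    print(sess.run(demo_y))\n",
    "    # with feed_dict: the fed value overrides the default\n",
    "    print(sess.run(demo_y, feed_dict={demo_x: [10., 20.]}))"
   ]
  },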
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, after defining the computational graph, the next code snippet will train the multilayer perceptron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001 | AvgCost: 0.007\n",
      "Epoch: 002 | AvgCost: 0.475\n",
      "Epoch: 003 | AvgCost: 0.237\n",
      "Epoch: 004 | AvgCost: 0.183\n",
      "Epoch: 005 | AvgCost: 0.151\n",
      "Epoch: 006 | AvgCost: 0.128\n",
      "Epoch: 007 | AvgCost: 0.111\n",
      "Epoch: 008 | AvgCost: 0.097\n",
      "Epoch: 009 | AvgCost: 0.086\n",
      "Epoch: 010 | AvgCost: 0.077\n",
      "Epoch: 011 | AvgCost: 0.069\n",
      "Epoch: 012 | AvgCost: 0.062\n",
      "Epoch: 013 | AvgCost: 0.056\n",
      "Epoch: 014 | AvgCost: 0.050\n",
      "Epoch: 015 | AvgCost: 0.045\n"
     ]
    }
   ],
   "source": [
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run('train_iter_init')\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    saver0 = tf.train.Saver()\n",
    "    \n",
    "    avg_cost = 0.\n",
    "    iter_per_epoch = n_iter // n_epochs\n",
    "    epoch = 0\n",
    "\n",
    "    for i in range(n_iter):\n",
    "        \n",
    "        _, cost = sess.run(['train', 'cost:0'])\n",
    "        avg_cost += cost\n",
    "        \n",
    "        # NOTE: this condition is also true at i == 0, so the first\n",
    "        # reported value is a single batch cost divided by iter_per_epoch\n",
    "        if not i % iter_per_epoch:\n",
    "            epoch += 1\n",
    "            avg_cost /= iter_per_epoch\n",
    "            print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch, avg_cost))\n",
    "            avg_cost = 0.\n",
    "            sess.run('train_iter_init')\n",
    "    \n",
    "    saver0.save(sess, save_path='./mlp')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Feeding new datapoints through placeholders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how we can feed new data points that are not part of the input pipeline to the network, let's load the test images into Python and pass them to the graph using a `feed_dict`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./mlp\n",
      "Test accuracy: 96.9%\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.image as mpimg\n",
    "import numpy as np\n",
    "import glob\n",
    "\n",
    "\n",
    "img_paths = np.array([p for p in glob.iglob('mnist_test/*/*.jpg')])\n",
    "labels = np.array([int(path.split('/')[1]) for path in img_paths])\n",
    "\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    \n",
    "    saver1 = tf.train.import_meta_graph('./mlp.meta')\n",
    "    saver1.restore(sess, save_path='./mlp')\n",
    "    \n",
    "    num_correct = 0\n",
    "    cnt = 0\n",
    "    for path, lab in zip(img_paths, labels):\n",
    "        cnt += 1\n",
    "        # NOTE: mpimg.imread returns raw uint8 pixel values in [0, 255]\n",
    "        # for JPEG files, whereas the training pipeline fed images\n",
    "        # that were normalized to the [0, 1] range\n",
    "        image = mpimg.imread(path)\n",
    "        image = image.reshape(1, 28, 28, 1)\n",
    "        \n",
    "        pred = sess.run('prediction:0', \n",
    "                         feed_dict={'images:0': image})\n",
    "\n",
    "        num_correct += int(lab == pred[0])\n",
    "    acc = num_correct / cnt * 100\n",
    "\n",
    "print('Test accuracy: %.1f%%' % acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
